% www.gusucode.com > 有监督的 CNN 网络完成对MNIST 数字的识别 > 有监督的 CNN 网络完成对MNIST 数字的识别/CNN—卷积神经网络数字识别/@cnn/train.m

    function [cnet, perf_plot] = train(cnet,Ip,labels,I_testp, labels_test)

%TRAIN Train a convolutional neural network using stochastic Levenberg-Marquardt
%
%  Syntax
%
%    cnet = train(cnet,Ip,labels,I_testp,labels_test)
%    [cnet, perf_plot] = train(cnet,Ip,labels,I_testp,labels_test)
%
%  Description
%   Input:
%    cnet - convolutional neural network class object. The fields read here
%           are epochs, teta (learning rate), teta_dec (per-epoch learning
%           rate multiplier), mu (Levenberg-Marquardt damping), HcalcMode,
%           Hrecalc and HrecalcSamplesNum.
%           HcalcMode selects how the diagonal Hessian approximation is kept:
%             0 - running estimate, updated on every pattern
%             1 - average over HrecalcSamplesNum samples, recomputed whenever
%                 mod(t*numPats+n,Hrecalc)==0
%             2 - no Hessian, plain gradient descent
%    Ip - cell array containing preprocessed images of handwritten digits
%    labels - numeric array of digit labels (0..9) corresponding to the
%             images; it is indexed as labels(n), so it must not be a cell
%             array
%    I_testp - preprocessed images of the test set (passed to calcMCR)
%    labels_test - labels corresponding to the images of the test set
%   Output:
%    cnet - trained convolutional neural network
%    perf_plot - performance data: every 10 patterns, the mean of
%                sqrt(mse) over the patterns n-10..n is appended
%
%  Training can be interrupted from the GUI: when the abort button's
%  UserData becomes non-empty the function returns the network and the
%  performance data collected so far.


%Initialize GUI
h_gui = cnn_gui();
%Progress bars
h_HessPatch = findobj(h_gui,'Tag','HessPatch');
h_HessEdit = findobj(h_gui,'Tag','HessPrEdit');
h_TrainPatch = findobj(h_gui,'Tag','TrainPatch');
h_TrainEdit = findobj(h_gui,'Tag','TrainPrEdit');
%Axes
h_MCRaxes = findobj(h_gui,'Tag','MCRaxes');
h_RMSEaxes = findobj(h_gui,'Tag','RMSEaxes');
%Info textboxes
h_EpEdit = findobj(h_gui,'Tag','EpEdit');
h_ItEdit = findobj(h_gui,'Tag','ItEdit');
h_RMSEedit = findobj(h_gui,'Tag','RMSEedit');
h_MCRedit = findobj(h_gui,'Tag','MCRedit');
h_TetaEdit = findobj(h_gui,'Tag','TetaEdit');
%Buttons
h_AbortButton = findobj(h_gui,'Tag','AbortButton');

tic;    %Fix the start time
perf_plot = []; %Array for storing performance data
%Coefficient, determining the running estimation of diagonal
%Hessian approximation leak
gamma = 0.1;

%Number of training patterns
numPats = length(Ip);
%Per-pattern error storage, overwritten on every epoch
perf = zeros(1,numPats);
%Calculate the size of network
net_size = cnn_size(cnet);

%Sparse identity used for the damping term, and the (initially zero)
%diagonal Hessian approximation
ii = sparse(1:net_size,1:net_size,ones(1,net_size));
jj = sparse(0);

%Initial MCR calculation (on the first 100 test samples)
mcr(1)=calcMCR(cnet,I_testp, labels_test, 1:100);
plot(h_MCRaxes,mcr);
SetText(h_MCRedit,mcr(end));

if(cnet.HcalcMode == 1)
    %Initial Hessian estimate from the first samples of the training set.
    %The sample count is clamped so a small training set cannot be
    %indexed out of range.
    hessSamples = min(cnet.HrecalcSamplesNum,numPats);
    [cnet,jj] = calcAvgHessian(cnet,Ip,labels,1:hessSamples,h_HessPatch,h_HessEdit);
end
%For all epochs
for t=1:cnet.epochs
    SetText(h_EpEdit,t);
    SetTextHP(h_TetaEdit,cnet.teta);
    %For all patterns
    for n=1:numPats
        %Setting the right output to 1, others to -1
        d = -ones(1,10);
        d(labels(n)+1) = 1;
        %Simulating
        [out, cnet] = sim(cnet,Ip{n});
        %Calculate the error
        e = out-d;
        %Calculate Jacobian times error, or in other words calculate
        %gradient
        [cnet,je] = calcje(cnet,e);
        %Calculate Hessian diagonal approximation
        if(cnet.HcalcMode == 0)
            [cnet,hx] = calchx(cnet);
            %Calculate the running estimate of Hessian diagonal approximation
            jj = gamma*diag(sparse(hx))+sparse((1-gamma)*jj);
        end
        if(cnet.HcalcMode == 1)
            if(mod(t*numPats+n,cnet.Hrecalc)==0) %If it is time to recalculate Hessian
                hessSamples = min(cnet.HrecalcSamplesNum,numPats);
                %Use the window starting at the current pattern; near the
                %end of the set shift it back so it stays inside 1..numPats
                stInd = min(n,numPats-hessSamples+1);
                %The helper runs in its own workspace, so the error e and
                %gradient je of pattern n computed above are not
                %overwritten by the Hessian samples
                [cnet,jj] = calcAvgHessian(cnet,Ip,labels,stInd:stInd+hessSamples-1,h_HessPatch,h_HessEdit);
            end
        end


        %The following is useful for debugging.
%===========DEBUG
%        tmp(1)=check_finit_dif(cnet,1,Ip{n},d,1);
%===========DEBUG

        perf(n) = mse(e); %Store the error

        if(cnet.HcalcMode == 2) %Gradient descent
            dW = cnet.teta*je;
        else
            %Levenberg-Marquardt: scale the gradient step by the damped
            %diagonal Hessian approximation
            dW = (jj+cnet.mu*ii)\(cnet.teta*je);
        end
        %Apply calculated weight updates
        cnet = adapt_dw(cnet,dW);

        %Plot mean of performance for every N patterns
        if(n>1)
            if(~mod(n-1,200))
                %Every 200 patterns re-evaluate the misclassification rate
                mcr = [mcr calcMCR(cnet,I_testp, labels_test, 1:100)];
                plot(h_MCRaxes,mcr);
                SetText(h_MCRedit,mcr(end));
            end
            if(~mod(n-1,10))
                %Every 10 patterns append the mean RMSE of patterns n-10..n
                perf_plot = [perf_plot,mean(sqrt(perf(n-10:n)))];
                plot(h_RMSEaxes,perf_plot);
                SetText(h_RMSEedit,perf_plot(end));
            end
        end

        SetTrainingProgress(h_TrainPatch,h_TrainEdit,(n+(t-1)*numPats)/(numPats*cnet.epochs));
        SetText(h_ItEdit,n);
        drawnow;
        %The abort button signals a stop request through its UserData
        if(~isempty(get(h_AbortButton,'UserData')))
            fprintf('Training aborted \n');
            return;
        end
    end
    %Decay the learning rate after every epoch
    cnet.teta = cnet.teta*cnet.teta_dec;
end
toc

%Average the diagonal Hessian approximation over a set of training samples
%cnet - network object
%Ip - cell array of training images
%labels - numeric array of training labels (0..9)
%idx - indices of the training samples to average over
%hp - handle of the Hessian progress patch
%hs - handle of the Hessian progress editbox
%Returns the network object as left by the last sample and the averaged
%sparse diagonal matrix jj (sparse(0) when idx is empty)
function [cnet,jj] = calcAvgHessian(cnet,Ip,labels,idx,hp,hs)
numSamples = length(idx);
%Start from zero so the result is a true average of the requested samples
jj = sparse(0);
for k=1:numSamples
    %Setting the right output to 1, others to -1
    d = -ones(1,10);
    d(labels(idx(k))+1) = 1;
    %Simulating
    [out, cnet] = sim(cnet,Ip{idx(k)});
    %Calculate the error
    e = out-d;
    %The gradient itself is not needed here; calcje is still called before
    %calchx, in the same order the training loop uses
    [cnet,je] = calcje(cnet,e);
    [cnet,hx] = calchx(cnet);
    jj = jj+diag(sparse(hx));
    SetHessianProgress(hp,hs,k/numSamples);
end
%Averaging
if(numSamples>0)
    jj = jj/numSamples;
end

%Update the Hessian-calculation progress bar and its percentage label,
%then redraw immediately
%hp - handle of the patch drawn as the progress bar
%hs - handle of the editbox showing the percentage
%pr - completed fraction, a value from 0 to 1
function SetHessianProgress(hp,hs,pr)
percentDone = 100*pr;
%The two right-hand vertices of the patch are moved to the percentage
set(hp,'XData',[0 percentDone percentDone 0]);
percentLabel = [num2str(percentDone,'%5.2f'),'%'];
set(hs,'String',percentLabel);
drawnow;


%Update the training progress bar and its percentage label
%(no redraw is forced here; the caller issues drawnow itself)
%hp - handle of the patch drawn as the progress bar
%hs - handle of the editbox showing the percentage
%pr - completed fraction, a value from 0 to 1
function SetTrainingProgress(hp,hs,pr)
percentDone = 100*pr;
%The two right-hand vertices of the patch are moved to the percentage
set(hp,'XData',[0 percentDone percentDone 0]);
percentLabel = [num2str(percentDone,'%5.2f'),'%'];
set(hs,'String',percentLabel);

%Show a number in the given edit box, formatted with two decimals
%hs - handle of textbox
%num - number to convert and display
function SetText(hs,num)
formatted = num2str(num,'%5.2f');
set(hs,'String',formatted);

%Show a number in the given edit box with high precision
%(exponential notation, three decimals)
%hs - handle of textbox
%num - number to convert and display
function SetTextHP(hs,num)
formatted = num2str(num,'%5.3e');
set(hs,'String',formatted);